{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:05.813486Z",
     "start_time": "2025-01-23T16:09:05.806587Z"
    },
    "collapsed": true,
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:27.632359Z",
     "iopub.status.busy": "2025-01-23T16:15:27.632226Z",
     "iopub.status.idle": "2025-01-23T16:15:32.555510Z",
     "shell.execute_reply": "2025-01-23T16:15:32.554982Z",
     "shell.execute_reply.started": "2025-01-23T16:15:27.632342Z"
    },
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=10, micro=14, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.0\n",
      "torch 2.5.1+cu124\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n",
    "\n",
    "# Seed every RNG in play so \"Restart & Run All\" is reproducible: torch drives weight\n",
    "# init and DataLoader(shuffle=True); python/numpy RNGs are seeded for completeness.\n",
    "# np.random.seed goes last because it returns None, so the cell shows no extra repr.\n",
    "seed = 42\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a3c3b37d6bda167b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.014955Z",
     "start_time": "2025-01-23T16:09:06.008545Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.556757Z",
     "iopub.status.busy": "2025-01-23T16:15:32.556469Z",
     "iopub.status.idle": "2025-01-23T16:15:32.560836Z",
     "shell.execute_reply": "2025-01-23T16:15:32.560442Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.556739Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "length 1115394\n",
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You\n"
     ]
    }
   ],
   "source": [
    "# Read the whole corpus into one string (explicit UTF-8 so the result does not depend on the locale).\n",
    "with open(\"./shakespeare.txt\", \"r\", encoding=\"utf8\") as corpus_file:\n",
    "    text = corpus_file.read()\n",
    "\n",
    "print(\"length\", len(text))\n",
    "print(text[:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e8a6265b1b95715",
   "metadata": {},
   "source": [
    "## 构造字典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed24e634656c7dfb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.027182Z",
     "start_time": "2025-01-23T16:09:06.016958Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.561432Z",
     "iopub.status.busy": "2025-01-23T16:15:32.561290Z",
     "iopub.status.idle": "2025-01-23T16:15:32.576708Z",
     "shell.execute_reply": "2025-01-23T16:15:32.575994Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.561418Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65\n",
      "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
     ]
    }
   ],
   "source": [
    "# Vocabulary: every distinct character of the corpus, sorted so the ids assigned next are deterministic.\n",
    "unique_chars = set(text)\n",
    "vocab = sorted(unique_chars)\n",
    "print(len(vocab))\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4ebd525b8209c99",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.032839Z",
     "start_time": "2025-01-23T16:09:06.028186Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.577609Z",
     "iopub.status.busy": "2025-01-23T16:15:32.577375Z",
     "iopub.status.idle": "2025-01-23T16:15:32.581188Z",
     "shell.execute_reply": "2025-01-23T16:15:32.580571Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.577583Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n"
     ]
    }
   ],
   "source": [
    "# Map each character to its position in the sorted vocabulary, i.e. its integer id.\n",
    "char2idx = dict(zip(vocab, range(len(vocab))))\n",
    "print(char2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b9f0c2607c3d284a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.038647Z",
     "start_time": "2025-01-23T16:09:06.034846Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.582253Z",
     "iopub.status.busy": "2025-01-23T16:15:32.581818Z",
     "iopub.status.idle": "2025-01-23T16:15:32.585402Z",
     "shell.execute_reply": "2025-01-23T16:15:32.584878Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.582223Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
      " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
      " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
      " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
     ]
    }
   ],
   "source": [
    "# Inverse lookup as a numpy array of characters, so idx2char[ids] decodes a whole id array in one indexing step.\n",
    "idx2char = np.asarray(vocab)\n",
    "print(idx2char)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "493c82cc4b1ab425",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.149768Z",
     "start_time": "2025-01-23T16:09:06.039652Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.586435Z",
     "iopub.status.busy": "2025-01-23T16:15:32.586055Z",
     "iopub.status.idle": "2025-01-23T16:15:32.669754Z",
     "shell.execute_reply": "2025-01-23T16:15:32.669288Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.586410Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1115394,)\n",
      "1115394\n",
      "[18 47 56 57 58  1 15 47 58 47]\n",
      "First Citi\n"
     ]
    }
   ],
   "source": [
    "# Encode the corpus: one integer id per character, in text order.\n",
    "text_as_int = np.array([char2idx[ch] for ch in text])\n",
    "for summary in (text_as_int.shape, len(text_as_int), text_as_int[:10], text[:10]):\n",
    "    print(summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12e5e79707eddd0d",
   "metadata": {},
   "source": [
    "## 把莎士比亚文集分成一个一个的样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f1afc949c019504a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.157428Z",
     "start_time": "2025-01-23T16:09:06.150774Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.672213Z",
     "iopub.status.busy": "2025-01-23T16:15:32.671866Z",
     "iopub.status.idle": "2025-01-23T16:15:32.677452Z",
     "shell.execute_reply": "2025-01-23T16:15:32.676875Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.672197Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "\n",
    "class CharDataset(Dataset):\n",
    "    \"\"\"Split an id sequence into non-overlapping windows of ``seq_length + 1`` ids.\n",
    "\n",
    "    The extra id exists because the target is the input shifted left by one\n",
    "    position (next-character prediction). Trailing ids that do not fill a whole\n",
    "    window are dropped.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, text_as_int, seq_length):\n",
    "        self.sub_len = seq_length + 1  # seq_length inputs + 1 id for the shifted target\n",
    "        self.text_as_int = text_as_int\n",
    "        self.num_seq = len(text_as_int) // self.sub_len  # number of complete windows\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return window ``index`` (``sub_len`` ids); negative indices count from the end.\"\"\"\n",
    "        # Without this bounds check an out-of-range index silently yields a short or\n",
    "        # empty slice, and plain ``for sample in ds`` iteration never terminates.\n",
    "        pos = index + self.num_seq if index < 0 else index\n",
    "        if not 0 <= pos < self.num_seq:\n",
    "            raise IndexError(f\"CharDataset index {index} out of range for length {self.num_seq}\")\n",
    "        start = pos * self.sub_len\n",
    "        return self.text_as_int[start:start + self.sub_len]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_seq\n",
    "\n",
    "\n",
    "def collate_fct(batch):\n",
    "    \"\"\"Stack windows into ``(inputs, targets)`` int64 tensors of shape (batch, seq_length).\n",
    "\n",
    "    ``targets[:, t]`` is the id that follows ``inputs[:, t]`` in the corpus.\n",
    "    \"\"\"\n",
    "    windows = np.stack(batch)  # (batch, seq_length + 1)\n",
    "    # Build int64 tensors directly; going through torch.Tensor (float32) and back is a needless round trip.\n",
    "    src = torch.tensor(windows[:, :-1], dtype=torch.int64)  # drop the last id\n",
    "    trg = torch.tensor(windows[:, 1:], dtype=torch.int64)  # drop the first id\n",
    "    return src, trg\n",
    "\n",
    "\n",
    "# seq_length=100 -> each sample holds 101 ids: 100 inputs plus 1 extra id for the shifted target.\n",
    "train_ds = CharDataset(text_as_int, seq_length=100)\n",
    "# shuffle=True visits windows in random order, so consecutive batches are NOT\n",
    "# contiguous text and the RNN hidden state should not be carried between them.\n",
    "train_loader = DataLoader(train_ds, batch_size=64,\n",
    "                          collate_fn=collate_fct, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "73efe8a799e66576",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.164919Z",
     "start_time": "2025-01-23T16:09:06.159434Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.678312Z",
     "iopub.status.busy": "2025-01-23T16:15:32.678083Z",
     "iopub.status.idle": "2025-01-23T16:15:32.719147Z",
     "shell.execute_reply": "2025-01-23T16:15:32.718692Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.678288Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 100])\n",
      "torch.Size([64, 100])\n"
     ]
    }
   ],
   "source": [
    "# Sanity-check one batch: inputs and targets should both be (batch_size, seq_length).\n",
    "for datas, labels in train_loader:\n",
    "    for tensor in (datas, labels):\n",
    "        print(tensor.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a771010f6cc72a4",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f5c94a2a0ba22d84",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.170649Z",
     "start_time": "2025-01-23T16:09:06.165925Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.720073Z",
     "iopub.status.busy": "2025-01-23T16:15:32.719692Z",
     "iopub.status.idle": "2025-01-23T16:15:32.724419Z",
     "shell.execute_reply": "2025-01-23T16:15:32.723890Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.720046Z"
    }
   },
   "outputs": [],
   "source": [
    "class CharRNN(nn.Module):\n",
    "    \"\"\"Character-level language model: Embedding -> single-layer unidirectional RNN -> Linear.\n",
    "\n",
    "    Unlike a classifier that keeps only the last RNN output, logits are produced at every time step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024):\n",
    "        super().__init__()\n",
    "        # The attribute names embedding / rnn / fc become the state_dict key prefixes.\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
    "        # batch_first=True: input/output are (batch, seq, feature); the hidden state is\n",
    "        # still (num_layers, batch, hidden_dim). num_layers / bidirectional keep their defaults.\n",
    "        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_dim, vocab_size)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Map ids (batch, seq_len) to ``(logits, hidden)``; logits are (batch, seq_len, vocab_size).\n",
    "\n",
    "        ``hidden=None`` lets nn.RNN start from its default zero state.\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(x)  # (batch, seq_len) -> (batch, seq_len, embedding_dim)\n",
    "        rnn_out, last_hidden = self.rnn(embedded, hidden)  # -> (batch, seq_len, hidden_dim)\n",
    "        logits = self.fc(rnn_out)  # -> (batch, seq_len, vocab_size)\n",
    "        return logits, last_hidden\n",
    "\n",
    "\n",
    "vocab_size = len(vocab)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "55f024c0091d7cc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.184668Z",
     "start_time": "2025-01-23T16:09:06.171652Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.725445Z",
     "iopub.status.busy": "2025-01-23T16:15:32.725105Z",
     "iopub.status.idle": "2025-01-23T16:15:32.738370Z",
     "shell.execute_reply": "2025-01-23T16:15:32.737928Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.725419Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 单层单向 RNN \n",
      "            embedding.weight            paramerters num: 16640\n",
      "            rnn.weight_ih_l0            paramerters num: 262144\n",
      "            rnn.weight_hh_l0            paramerters num: 1048576\n",
      "             rnn.bias_ih_l0             paramerters num: 1024\n",
      "             rnn.bias_hh_l0             paramerters num: 1024\n",
      "               fc.weight                paramerters num: 66560\n",
      "                fc.bias                 paramerters num: 65\n"
     ]
    }
   ],
   "source": [
    "# Per-tensor parameter counts, plus the total, for the default CharRNN configuration.\n",
    "print(\" 单层单向 RNN \")\n",
    "total_params = 0\n",
    "for key, value in CharRNN(vocab_size).named_parameters():\n",
    "    total_params += value.numel()\n",
    "    print(f\"{key:^40}parameters num: {value.numel()}\")\n",
    "print(f\"{'total':^40}parameters num: {total_params}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ab858cede55e136",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9774a1ebaacafd67",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.191099Z",
     "start_time": "2025-01-23T16:09:06.185670Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.739269Z",
     "iopub.status.busy": "2025-01-23T16:15:32.738904Z",
     "iopub.status.idle": "2025-01-23T16:15:32.744089Z",
     "shell.execute_reply": "2025-01-23T16:15:32.743522Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.739243Z"
    }
   },
   "outputs": [],
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    \"\"\"Save a model ``state_dict`` to ``save_dir`` every ``save_step`` steps.\n",
    "\n",
    "    Args:\n",
    "        save_dir: checkpoint directory; created if it does not exist.\n",
    "        save_step: only calls with ``step % save_step == 0`` may save.\n",
    "        save_best_only: if True, keep one file ``03_text_generation.ckpt`` that is\n",
    "            overwritten whenever ``metric`` is >= the best metric seen so far\n",
    "            (higher is better); if False, write ``{step}.ckpt`` at every save step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        # Start at -inf, not -1: the training loop passes metric=-loss, which is always\n",
    "        # negative, so a -1 floor silently skipped every checkpoint while loss > 1.\n",
    "        self.best_metric = float(\"-inf\")\n",
    "        os.makedirs(self.save_dir, exist_ok=True)\n",
    "\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        \"\"\"Save ``state_dict`` if ``step`` is a save step (and, in best-only mode, ``metric`` is a new best).\n",
    "\n",
    "        ``state_dict`` is saved as given; the training loop passes ``model.state_dict()``,\n",
    "        i.e. model weights only, no optimizer state.\n",
    "        \"\"\"\n",
    "        if step % self.save_step > 0:\n",
    "            return  # not a save step\n",
    "\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None, \"metric is required when save_best_only=True\"\n",
    "            if metric >= self.best_metric:\n",
    "                # Overwrite the single best checkpoint.\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"03_text_generation.ckpt\"))\n",
    "                self.best_metric = metric\n",
    "        else:\n",
    "            # Keep one file per save step.\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "16fb5bafe219f6d1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:06.199221Z",
     "start_time": "2025-01-23T16:09:06.193105Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.745060Z",
     "iopub.status.busy": "2025-01-23T16:15:32.744803Z",
     "iopub.status.idle": "2025-01-23T16:15:32.750511Z",
     "shell.execute_reply": "2025-01-23T16:15:32.749996Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.745021Z"
    }
   },
   "outputs": [],
   "source": [
    "def _detach_hidden(hidden):\n",
    "    \"\"\"Cut the autograd history of an RNN hidden state (None, a tensor, or a tuple of tensors).\"\"\"\n",
    "    if hidden is None:\n",
    "        return None\n",
    "    if isinstance(hidden, tuple):  # e.g. nn.LSTM returns (h, c)\n",
    "        return tuple(h.detach() for h in hidden)\n",
    "    return hidden.detach()\n",
    "\n",
    "\n",
    "def training(model,\n",
    "             train_loader,\n",
    "             epoch,\n",
    "             loss_fct,\n",
    "             optimizer,\n",
    "             save_ckpt_callback=None,\n",
    "             stateful=False  # only meaningful when batches are consecutive text (no shuffling)\n",
    "             ):\n",
    "    \"\"\"Train ``model`` for ``epoch`` passes over ``train_loader`` and record the loss of every step.\n",
    "\n",
    "    Args:\n",
    "        model: module called as ``model(datas, hidden=...)`` returning ``(logits, hidden)``,\n",
    "            with ``logits`` shaped (batch, seq_len, num_classes).\n",
    "        train_loader: sized iterable of ``(datas, labels)`` batches; labels are class ids.\n",
    "        epoch: number of passes over ``train_loader``.\n",
    "        loss_fct: loss applied to the flattened logits (N, num_classes) and labels (N,).\n",
    "        optimizer: optimizer over ``model.parameters()``.\n",
    "        save_ckpt_callback: optional ``callback(step, state_dict, metric=...)``; it receives\n",
    "            ``-loss`` so that a higher metric is better.\n",
    "        stateful: if True, the final hidden state of one batch seeds the next batch.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"train\": [{\"loss\": float, \"step\": int}, ...]}`` with one entry per optimizer step.\n",
    "    \"\"\"\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "    }\n",
    "\n",
    "    global_step = 0\n",
    "    model.train()\n",
    "    hidden = None\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                if stateful:\n",
    "                    # Detach so backward() stops at this batch; keeping the old graph makes the\n",
    "                    # second backward() fail (\"Trying to backward through the graph a second time\").\n",
    "                    hidden = _detach_hidden(hidden)\n",
    "                    # A state sized for the previous batch cannot seed a smaller final batch.\n",
    "                    first_state = hidden[0] if isinstance(hidden, tuple) else hidden\n",
    "                    if first_state is not None and first_state.size(1) != datas.size(0):\n",
    "                        hidden = None\n",
    "                else:\n",
    "                    hidden = None  # shuffled batches are unrelated: start each batch fresh\n",
    "\n",
    "                logits, hidden = model(datas, hidden=hidden)\n",
    "                # CrossEntropyLoss expects (N, C) logits and (N,) targets, so fold batch and time together.\n",
    "                loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                loss = loss.item()  # Python float; .item() already copies to host\n",
    "\n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss,\n",
    "                    \"step\": global_step\n",
    "                })\n",
    "\n",
    "                if save_ckpt_callback is not None:\n",
    "                    # model.state_dict() holds the model's parameters and buffers only, not the optimizer state.\n",
    "                    save_ckpt_callback(global_step, model.state_dict(), metric=-loss)\n",
    "\n",
    "                pbar.update(1)\n",
    "                global_step += 1\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "\n",
    "    return record_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "80c65730aa391cf6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-23T16:09:23.957375Z",
     "start_time": "2025-01-23T16:09:06.200227Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:15:32.751605Z",
     "iopub.status.busy": "2025-01-23T16:15:32.751218Z",
     "iopub.status.idle": "2025-01-23T16:17:35.040386Z",
     "shell.execute_reply": "2025-01-23T16:17:35.039963Z",
     "shell.execute_reply.started": "2025-01-23T16:15:32.751578Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17300/17300 [02:00<00:00, 143.45it/s, epoch=99]\n"
     ]
    }
   ],
   "source": [
    "epoch = 100\n",
    "\n",
    "# Model: single-layer, unidirectional RNN (nn.Module.to returns the module itself).\n",
    "model = CharRNN(vocab_size=vocab_size).to(device)\n",
    "\n",
    "# Loss and optimizer\n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Checkpointing: save_step = batches per epoch, so a save is considered at the first\n",
    "# step of every epoch; only the best (lowest-loss) state_dict is kept.\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\n",
    "    save_dir=\"checkpoints\",\n",
    "    save_step=len(train_loader),\n",
    "    save_best_only=True,\n",
    ")\n",
    "\n",
    "# Train. stateful keeps its default (False) because train_loader shuffles.\n",
    "record_dict = training(\n",
    "    model,\n",
    "    train_loader,\n",
    "    epoch,\n",
    "    loss_fct,\n",
    "    optimizer,\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6aa3eaf8a1631029",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T16:17:35.041355Z",
     "iopub.status.busy": "2025-01-23T16:17:35.040894Z",
     "iopub.status.idle": "2025-01-23T16:17:35.136390Z",
     "shell.execute_reply": "2025-01-23T16:17:35.135897Z",
     "shell.execute_reply.started": "2025-01-23T16:17:35.041336Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAGdCAYAAADXIOPgAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAWVVJREFUeJzt3XlYU1f+BvA3CUkAIawCsgriggsKrqh1qbuO1S62te3Yxe46046t7c92Zlrbmep0GcdpO053u1lbO9VuVsUFrYo7uIugCC4ssoY1hOT8/khyIbJYFs3FvJ/n8Qm5uUlOvr0NL+eec65CCCFAREREJDNKRzeAiIiIqDEMKURERCRLDClEREQkSwwpREREJEsMKURERCRLDClEREQkSwwpREREJEsMKURERCRLLo5uwG9hNptx6dIleHp6QqFQOLo5RERE9BsIIVBWVobg4GAolS3vF+kQIeXSpUsICwtzdDOIiIioFc6fP4/Q0NAWP69DhBRPT08Alg+p0+na7XWNRiM2bdqEiRMnQq1Wt9vrdjSsA2sAsAY2rANrALAGNm2tg16vR1hYmPR7vKU6REixneLR6XTtHlLc3d2h0+mc/iB09jqwBqyBDevAGgCsgU171aG1QzU4cJaIiIhkiSGFiIiIZIkhhYiIiGSJIYWIiIhkiSGFiIiIZIkhhYiIiGSJIYWIiIhkiSGFiIiIZIkhhYiIiGSJIYWIiIhkiSGFiIiIZIkhhYiIiGTJqUPKJ7uz8L9MJdJyyxzdFCIiIrqCU4eU9cdysSNXifPFVY5uChEREV3BqUOK7cLRZiEc2g4iIiJqyKlDilJhiSnMKERERPLj1CHFmlHYk0JERCRDTh5S2JNCREQkV04dUpTWnhRmFCIiIvlx6pDCgbNERETy5dQhhQNniYiI5MupQ4qtK0UwpRAREcmOU4cUqSfFwe0gIiKihpw6pHBMChERkXw5dUjhmBQiIiL5cuqQUreYm2PbQURERA0xpIADZ4mIiOTIqUMKB84SERHJl1OHFA6cJSIikq82hZSlS5dCoVDg6aefbna/NWvWoFevXnB1dUW/fv2wfv36trxtu+G1e4iIiOSr1SFl//79eO+99xAbG9vsfrt378bs2bMxd+5cpKSkYObMmZg5cyaOHTvW2rduNxyTQkREJF+tCinl5eW499578cEHH8DHx6fZfZcvX47Jkydj4cKFiImJwauvvor4+Hi88847rWpwe+KYFCIiIvlyac2T5s2bh2nTpmH8+PH429/+1uy+ycnJWLBggd22SZMmYd26dU0+x2AwwGAwSPf1ej0AwGg0wmg0tqbJjRLCbHndWlO7vm5HY/vsrAFrUP/WWbEOrAHAGti0tQ5trV+LQ8rq1atx6NAh7N+//zftn5ubi8DAQLttgYGByM3NbfI5S5YsweLFixts37RpE9zd3VvW4Gbk5ykBKHHq1CmsLz3Zbq/bUSUmJjq6CQ7HGrAGNqwDawCwBjatrUNlZWWb3rdFIeX8+fN46qmnkJiYCFdX1za9cXMWLVpk1/ui1+sRFhaGiRMnQqfTtdv7bCpLRUphPnr06ImpN0W12+t2NEajEYmJiZgwYQLUarWjm+MQrAFrYMM6sAYAa2DT1jrYzoS0VotCysGDB5Gfn4/4+Hhpm8lkwo4dO/DOO+/AYDBApVLZPScoKAh5eXl22/Ly8hAUFNTk+2i1Wmi12gbb1Wp1ux4sSqVlSI5CqXTqg9CmvevbEbEGrIEN68AaAKyBTWvr0NbatWjg7Lhx43D06FGkpqZK/wYNGoR7770XqampDQIKACQkJGDLli122xITE5GQkNCmhrcHDpwlIiKSrxb1pHh6eqJv37522zp16gQ/Pz9p+5w5cxASEoIlS5YAAJ566imMHj0ab731FqZNm4bVq1fjwIEDeP/999vpI7QeF3MjIiKSr3ZfcTY7Oxs5OTnS/eHDh2PVqlV4//330b9/f3z7
7bdYt25dg7DjCAolF3MjIiKSq1ZNQa4vKSmp2fsAMGvWLMyaNautb9XubD0pDClERETy49TX7pHGpDClEBERyY5ThxTbsvhmZhQiIiLZceqQorRdu8exzSAiIqJGOHVIsV0FmbN7iIiI5Me5Q4r1lmNSiIiI5MepQ0rdwFkHN4SIiIgacOqQwoGzRERE8uXkIcW2LD5TChERkdw4d0ix3vJ0DxERkfw4dUiRpiAzpBAREcmOU4cUTkEmIiKSL6cOKVzMjYiISL6cOqTYcJ0UIiIi+XHqkKKUTvc4uCFERETUAEMK2JNCREQkR04dUriYGxERkXwxpIADZ4mIiOTIuUMKeLqHiIhIrpw6pHAxNyIiIvly6pBSNyaFKYWIiEhunDyk2C4wSERERHLj3CHFessxKURERPLj1CGFi7kRERHJl5OHFMstO1KIiIjkx6lDCq+CTEREJF9OHlIst4woRERE8sOQAkBwUAoREZHsOHVIUXIKMhERkWw5dUixTUHmmBQiIiL5ce6QYutJYUYhIiKSHScPKZZbhhQiIiL5ceqQUjcmhSmFiIhIbpw8pFhuObmHiIhIfpw6pHDgLBERkXw5d0jhwFkiIiLZcvKQYrnlVZCJiIjkx6lDChdzIyIiki+nDikck0JERCRfzh1SOCaFiIhItpw8pFhuGVKIiIjkp0UhZcWKFYiNjYVOp4NOp0NCQgJ++eWXJvdfuXIlFAqF3T9XV9c2N7q92NZJ4WJuRERE8uPSkp1DQ0OxdOlSdO/eHUIIfPrpp5gxYwZSUlLQp0+fRp+j0+mQlpYm3bedYpED28BZLuZGREQkPy0KKdOnT7e7//e//x0rVqzAnj17mgwpCoUCQUFBrW/hNcSBs0RERPLVopBSn8lkwpo1a1BRUYGEhIQm9ysvL0dERATMZjPi4+Px2muvNRlobAwGAwwGg3Rfr9cDAIxGI4xGY2ub3IDZbLbeinZ93Y7G9tlZA9ag/q2zYh1YA4A1sGlrHdpaP4Vo4UpmR48eRUJCAqqrq+Hh4YFVq1Zh6tSpje6bnJyM9PR0xMbGorS0FG+++SZ27NiB48ePIzQ0tMn3ePnll7F48eIG21etWgV3d/eWNLdZBwsU+Cxdhe46M+b3Mbfb6xIRERFQWVmJe+65B6WlpdDpdC1+fotDSk1NDbKzs1FaWopvv/0WH374IbZv347evXtf9blGoxExMTGYPXs2Xn311Sb3a6wnJSwsDAUFBa36kE35IfUCnvnfCQyJ8MaXDw9pt9ftaIxGIxITEzFhwgSo1WpHN8chWAPWwIZ1YA0A1sCmrXXQ6/Xw9/dvdUhp8ekejUaD6OhoAMDAgQOxf/9+LF++HO+9995Vn6tWqxEXF4eMjIxm99NqtdBqtY0+vz0PFheV5eMLhcKpD0Kb9q5vR8QasAY2rANrALAGNq2tQ1tr1+Z1Usxms12vR3NMJhOOHj2KLl26tPVt2wWv3UNERCRfLepJWbRoEaZMmYLw8HCUlZVh1apVSEpKwsaNGwEAc+bMQUhICJYsWQIAeOWVVzBs2DBER0ejpKQEb7zxBrKysvDwww+3/ydpBa44S0REJF8tCin5+fmYM2cOcnJy4OXlhdjYWGzcuBETJkwAAGRnZ0OprOucKS4uxiOPPILc3Fz4+Phg4MCB2L17928av3I91C3mRkRERHLTopDy0UcfNft4UlKS3f1ly5Zh2bJlLW7U9VK3mBtjChERkdw497V7rLcMKURERPLj3CGF53uIiIhky7lDivWW1+4hIiKSH6cOKbwKMhERkXw5dUixTUE2c0V8IiIi2XHykGK5ZT8KERGR/Dh3SIFtMTfGFCIiIrlx6pAijUlhRiEiIpIdpw4pttM9XCeFiIhIfpw6pNStOOvghhAREVEDTh1SbD0pHDpLREQkP84dUsCeFCIiIrly6pDCgbNERETy5dQhRcGrIBMREcmWk4cUyy0jChERkfw4d0ix3nIxNyIiIvlx6pBim4LM
jEJERCQ/Th1SuJgbERGRfDl1SGFPChERkXw5dUjhwFkiIiL5cu6QAk5BJiIikiunDilczI2IiEi+nDqkcOAsERGRfDl5SFFcfSciIiJyCOcOKdZb9qQQERHJj1OHFE5BJiIiki+nDil1Y1Ic2w4iIiJqyKlDSl1PClMKERGR3Dh1SOFibkRERPLFkAIOnCUiIpIjpw4pHDhLREQkX04dUjgFmYiISL6cO6RwMTciIiLZcvKQYrnlFGQiIiL5ceqQwinIRERE8uXUIaVuTIpDm0FERESNcOqQoqw3JIW9KURERPLi1CGl/sBZZhQiIiJ5cfKQUvczpyETERHJi1OHFGX9nhQHtoOIiIgaalFIWbFiBWJjY6HT6aDT6ZCQkIBffvml2eesWbMGvXr1gqurK/r164f169e3qcHtqf4qKexJISIikpcWhZTQ0FAsXboUBw8exIEDB3DzzTdjxowZOH78eKP77969G7Nnz8bcuXORkpKCmTNnYubMmTh27Fi7NL6tOCaFiIhIvloUUqZPn46pU6eie/fu6NGjB/7+97/Dw8MDe/bsaXT/5cuXY/LkyVi4cCFiYmLw6quvIj4+Hu+88067NL6tFHazexzXDiIiImqo1WNSTCYTVq9ejYqKCiQkJDS6T3JyMsaPH2+3bdKkSUhOTm7t27YruynIHJVCREQkKy4tfcLRo0eRkJCA6upqeHh4YO3atejdu3ej++bm5iIwMNBuW2BgIHJzc5t9D4PBAIPBIN3X6/UAAKPRCKPR2NImN6m2trbuPWuMUCucM6jYatqete1oWAPWwIZ1YA0A1sCmrXVoa/1aHFJ69uyJ1NRUlJaW4ttvv8X999+P7du3NxlUWmPJkiVYvHhxg+2bNm2Cu7t7u72P0QzYSrBp4ya4trgaN5bExERHN8HhWAPWwIZ1YA0A1sCmtXWorKxs0/u2+NeyRqNBdHQ0AGDgwIHYv38/li9fjvfee6/BvkFBQcjLy7PblpeXh6CgoGbfY9GiRViwYIF0X6/XIywsDBMnToROp2tpk5tUUW0A9m4HAIyfMAE6N3W7vXZHYjQakZiYiAkTJkCtZg1YA+etAcA6AKwBwBrYtLUOtjMhrdXmvgOz2Wx3aqa+hIQEbNmyBU8//bS0LTExsckxLDZarRZarbbBdrVa3a4Hi8Zkln5Wubg49YEItH99OyLWgDWwYR1YA4A1sGltHdpauxaFlEWLFmHKlCkIDw9HWVkZVq1ahaSkJGzcuBEAMGfOHISEhGDJkiUAgKeeegqjR4/GW2+9hWnTpmH16tU4cOAA3n///TY1ur0oOQWZiIhItloUUvLz8zFnzhzk5OTAy8sLsbGx2LhxIyZMmAAAyM7OhlJZN2Fo+PDhWLVqFf785z/jhRdeQPfu3bFu3Tr07du3fT9FK3ExNyIiIvlqUUj56KOPmn08KSmpwbZZs2Zh1qxZLWrU9WK3TorjmkFERESNcOpr99RfcZY9KURERPLi1CEFABS2PhRmFCIiIllhSLHemhlSiIiIZIUhxZpSuCw+ERGRvDh9SLFhTwoREZG8OH1IsRXAzJRCREQkK04fUupPQyYiIiL5cPqQYsMpyERERPLi9CFFGjjLjEJERCQrDCnWW/akEBERyQtDivWWEYWIiEheGFKst4I9KURERLLCkMIxKURERLLk9CHFhsukEBERyYvThxRbTwoHzhIREcmL04cUWwGYUYiIiOTF6UOKDXtSiIiI5MXpQwqXxSciIpInhhTrLXtSiIiI5IUhxXrLjEJERCQvDCmc3UNERCRLDCnWW0YUIiIieXH6kGLDZfGJiIjkxelDSt3pHse2g4iIiOw5fUjhYm5ERETy5PQhxYYDZ4mIiOTF6UMKr4JMREQkTwwp1lsOnCUiIpIXhhTrLSMKERGRvDh9SAEXcyMiIpIlpw8pnN1DREQkT04fUmzYk0JERCQvTh9SOLuHiIhInhhSrLeCQ2eJiIhkhSHFems2O7QZRERE
dAWGFNvpHsc2g4iIiK7AkGK95cBZIiIieWFI4cBZIiIiWXL6kGLDZfGJiIjkxelDirSYm0NbQURERFdy+pBiwzEpRERE8tKikLJkyRIMHjwYnp6eCAgIwMyZM5GWltbsc1auXAmFQmH3z9XVtU2Nbk8KhSWcMKMQERHJS4tCyvbt2zFv3jzs2bMHiYmJMBqNmDhxIioqKpp9nk6nQ05OjvQvKyurTY1uT5zdQ0REJE8uLdl5w4YNdvdXrlyJgIAAHDx4EKNGjWryeQqFAkFBQa1r4TXG2T1ERETy1KKQcqXS0lIAgK+vb7P7lZeXIyIiAmazGfHx8XjttdfQp0+fJvc3GAwwGAzSfb1eDwAwGo0wGo1tabIdo9Eo9aTU1ta262t3JLbP7ayfH2ANANbAhnVgDQDWwKatdWhr/RSilXNvzWYzbrnlFpSUlGDnzp1N7pecnIz09HTExsaitLQUb775Jnbs2IHjx48jNDS00ee8/PLLWLx4cYPtq1atgru7e2ua26QVJ5Q4VarEvdEmDOnM7hQiIqL2UllZiXvuuQelpaXQ6XQtfn6rQ8oTTzyBX375BTt37mwybDTGaDQiJiYGs2fPxquvvtroPo31pISFhaGgoKBVH7K5ttz+7y04WaLEP27rg9viQtrttTsSo9GIxMRETJgwAWq12tHNcQjWgDWwYR1YA4A1sGlrHfR6Pfz9/VsdUlp1umf+/Pn46aefsGPHjhYFFABQq9WIi4tDRkZGk/totVpotdpGn3utDhaFUuXUByJwbevbUbAGrIEN68AaAKyBTWvr0NbatWh2jxAC8+fPx9q1a7F161ZERka2+A1NJhOOHj2KLl26tPi514JtTApXcyMiIpKXFvWkzJs3D6tWrcL3338PT09P5ObmAgC8vLzg5uYGAJgzZw5CQkKwZMkSAMArr7yCYcOGITo6GiUlJXjjjTeQlZWFhx9+uJ0/SuvYZvdwCjIREZG8tCikrFixAgAwZswYu+2ffPIJHnjgAQBAdnY2lMq6Dpri4mI88sgjyM3NhY+PDwYOHIjdu3ejd+/ebWt5O7H1pDCiEBERyUuLQspvGWOblJRkd3/ZsmVYtmxZixp1PXExNyIiInly+mv31J3ucWw7iIiIyB5Diu0H9qQQERHJCkOK9ZY9KURERPLCkCJdu4cphYiISE6cPqTYsCeFiIhIXpw+pHAKMhERkTwxpPB0DxERkSwxpFhvmVGIiIjkhSHFesvF3IiIiOSFIYWLuREREckSQ4r1VnDoLBERkaw4fUix4dkeIiIieXH6kKLk7B4iIiJZcvqQYsMxKURERPLi9CGlbp0Ux7aDiIiI7DGkWG85BZmIiEheGFKst4woRERE8sKQYr3lwFkiIiJ5YUiRFnNjSCEiIpIThhTrLTMKERGRvDh9SAGXxSciIpIlpw8pXBafiIhInhhSrLc83UNERCQvDClcFp+IiEiWGFKstxyTQkREJC8MKdZbdqQQERHJi9OHFHCdFCIiIlly+pBiKwDHpBAREcmL04cUXruHiIhInpw+pPB0DxERkTw5fUhRWPtQmFGIiIjkhSHFesspyERERPLCkGJLKRyVQkREJCsMKdZbs9mhzSAiIqIrMKRw4CwREZEsOX1I0VgrUFljcmxDiIiIyI7ThxQ3leVWX210bEOIiIjIjtOHFFcXy21Zda1jG0JERER2nD6k2HpSytiTQkREJCtOH1JcVZYBs+xJISIikpcWhZQlS5Zg8ODB8PT0REBAAGbOnIm0tLSrPm/NmjXo1asXXF1d0a9fP6xfv77VDW5vbjzdQ0REJEstCinbt2/HvHnzsGfPHiQmJsJoNGLixImoqKho8jm7d+/G7NmzMXfuXKSkpGDmzJmYOXMmjh071ubGtwfb6Z4qowlGExdLISIikguXluy8YcMGu/srV65EQEAADh48iFGjRjX6nOXLl2Py5MlYuHAhAODVV19FYmIi3nnnHfz3v/9tZbPbj6uq7ufy
6lr4dNI4rjFEREQkaVFIuVJpaSkAwNfXt8l9kpOTsWDBArttkyZNwrp165p8jsFggMFgkO7r9XoAgNFohNHYfgNcjUYjVErATa1EldGMovIqeGgUV3/iDcZW0/asbUfDGrAGNqwDawCwBjZtrUNb69fqkGI2m/H0009jxIgR6Nu3b5P75ebmIjAw0G5bYGAgcnNzm3zOkiVLsHjx4gbbN23aBHd399Y2uUlqmFAFBX7ZnIQwj3Z/+Q4jMTHR0U1wONaANbBhHVgDgDWwaW0dKisr2/S+rQ4p8+bNw7Fjx7Bz5842NaAxixYtsut90ev1CAsLw8SJE6HT6drtfYxGIxITE+Hv1Qn6gkr0HzQMw6Ka7hW6UdnqMGHCBKjVakc3xyFYA9bAhnVgDQDWwKatdbCdCWmtVoWU+fPn46effsKOHTsQGhra7L5BQUHIy8uz25aXl4egoKAmn6PVaqHVahtsV6vV1+Rg8XS1vGZlrXDqg/Fa1bcjYQ1YAxvWgTUAWAOb1tahrbVr0eweIQTmz5+PtWvXYuvWrYiMjLzqcxISErBlyxa7bYmJiUhISGhZS68hT+uys5yGTEREJB8t6kmZN28eVq1ahe+//x6enp7SuBIvLy+4ubkBAObMmYOQkBAsWbIEAPDUU09h9OjReOuttzBt2jSsXr0aBw4cwPvvv9/OH6X1PLW2kOLcA6SIiIjkpEU9KStWrEBpaSnGjBmDLl26SP++/vpraZ/s7Gzk5ORI94cPH45Vq1bh/fffR//+/fHtt99i3bp1zQ62vd7Yk0JERCQ/LepJEUJcdZ+kpKQG22bNmoVZs2a15K2uq7qQwp4UIiIiuXD6a/cAgIeWPSlERERyw5ACQOdmGX3MkEJERCQfDCmoGzir5+keIiIi2WBIAQfOEhERyRFDCjhwloiISI4YUlA3cLa0ij0pREREcsGQAiDc1w1KBVBQbkBuabWjm0NERERgSAFguXZPn2AvAMDezEIHt4aIiIgAhhSJ7erHe84ypBAREckBQ4rVsCg/AMCes0UObgkREREBDCmSQV19oVQAmQUVHJdCREQkAwwpVl5uanTr7AEAOJ1X5uDWEBEREUNKPUFergCA/DKDg1tCREREDCn1BOosISVPz9M9REREjsaQUk+ApxYAkM+QQkRE5HAMKfXU9aTwdA8REZGjMaTUE6iz9KTklbEnhYiIyNEYUuoJsPak5LMnhYiIyOEYUuqxne7JL6uGEMLBrSEiInJuDCn1dPawnO4xmgSKK40Obg0REZFzY0ipR+OihF8nDQBOQyYiInI0hpQrdLZOQ2ZIISIiciyGlCvUjUvh4FkiIiJHYki5gjQNmRcZJCIiciiGlCuEeLsDALKLKh3cEiIiIufGkHKF6ADLlZDT88sd3BIiIiLnxpByhe6BlpCSkV/OtVKIiIgciCHlCl39OkGlVKDcUItczvAhIiJyGIaUK2hclOjqZxmXkp7HUz5ERESOwpDSiO4BngA4LoWIiMiRGFIaUTcupczBLSEiInJeDCmNsM3wOX5J7+CWEBEROS+GlEYMjPCBi1KBIxdKseFYrqObQ0RE5JQYUhoR6uOOR0dFAQBe+uEYqo0mB7eIiIjI+TCkNOGP47rDt5MGeXoDT/sQERE5AENKE1zVKvQKsszyOXuZs3yIiIiuN4aUZnTrbBlAe7agwsEtISIicj4MKc2I6twJAHCG66UQERFddwwpzYhiTwoREZHDMKQ0o5u1JyWrsAK1JrODW0NERORcWhxSduzYgenTpyM4OBgKhQLr1q1rdv+kpCQoFIoG/3Jz5b/+SLCXG1zVShhNAnGvJmJZ4mleGZmIiOg6aXFIqaioQP/+/fHuu++26HlpaWnIycmR/gUEBLT0ra87pVIBD60LAKCsuhbLt6Tjv9vPOrhVREREzsGlpU+YMmUKpkyZ0uI3CggIgLe3d4uf52jDovzw05Ec6f4/NpxCoE6L2+JDHdgqIiKiG1+LQ0prDRgw
AAaDAX379sXLL7+MESNGNLmvwWCAwWCQ7uv1lsXUjEYjjEZju7XJ9lrNveYTo7rC112NuSMi8NmebHy0KwvPfXsE3fzd0LuLrt3a4ki/pQ43OtaANbBhHVgDgDWwaWsd2lo/hWjDIAuFQoG1a9di5syZTe6TlpaGpKQkDBo0CAaDAR9++CE+//xz7N27F/Hx8Y0+5+WXX8bixYsbbF+1ahXc3d1b29w2MwvgvZNKnCpVYnKoGVPCOJiWiIioKZWVlbjnnntQWloKna7lf9hf85DSmNGjRyM8PByff/55o4831pMSFhaGgoKCVn3IphiNRiQmJmLChAlQq9W/6Tlf7M3G4p9OYUhXHxRW1MDbTY2vHh4MhULRbu263lpThxsNa8Aa2LAOrAHAGti0tQ56vR7+/v6tDinX7XRPfUOGDMHOnTubfFyr1UKr1TbYrlarr8nB0pLXjY/wAwDsO1csbTt9uQp9Q7zavV3X27Wqb0fCGrAGNqwDawCwBjatrUNba+eQdVJSU1PRpUsXR7x1m8V00UHjYl+2LSfzHdQaIiKiG1eLe1LKy8uRkZEh3c/MzERqaip8fX0RHh6ORYsW4eLFi/jss88AAP/6178QGRmJPn36oLq6Gh9++CG2bt2KTZs2td+nuI40Lkr0CdYhJbtE2rb1VB6eGt9dui+EwIXiKoT6uHXo00BERESO1OKelAMHDiAuLg5xcXEAgAULFiAuLg5//etfAQA5OTnIzs6W9q+pqcEzzzyDfv36YfTo0Th8+DA2b96McePGtdNHuP4GhHkDADyta6gcvlCK/LJq6fGPdmbipte34duDFxzRPCIiohtCi3tSxowZ0+yqqytXrrS7/9xzz+G5555rccPkbHKfIKzcfQ73DotA8pkCHL5QiqRTl3Hn4DAAwN9+PgkAWPjtEcwaFObIphIREXVYvHZPKwyN8kPKXyZg4aSeuLlXIABg88k8AICh1iTt59dJAwCoMNQiKS0fZjOX1CciIvqtGFJaydtdA5VSgXExluX9d2YUoNpowpELpdI+SqVlPMrSX07hgU/249tDPP1DRET0WzGktFGfYB0CdVpU1piwN7MI+zKLpMculxlQVWPC9tOXAcDuMSIiImoeQ0obKRQK6ZTPJ7sypUBiczCrGNlFlQCAE5f01719REREHRVDSjv4/bAIaFyUSEq7jH2ZRVAoADe1CgDwXUrdKZ70/DLU1HIpfSIiot+CIaUd9A7WYcmt/QAAGpUS/7prAMb26gwA+O7QRWk/o0kgI7/cIW0kIiLqaByyLP6N6PaBoQjzdYefhwbdOnvgZE6Z3eMqpQIms8DJHD16B+tQVm3Ef5LOYMaAYPQKujGupkxERNSe2JPSjoZE+qJbZw8AQLhv3dWaNSolpvQNAgCcyLGMS3nmm8NYkXQG81elXP+GEhERdQAMKddIdICH9PO/Zw/A6B6W0z//O3QBG4/nYtMJy7oqGfnl+PHwJTy1OgXlhtoGr1NUUYNaE8exEBGR8+HpnmtkcFcfvDS9N/qGeGFwV19U1tTis+QsHL1Yisc+P2i374trj0JfXYu4MG88MCISAFBVY8LSX07i8z1ZGB8TiPfnDHLExyAiInIY9qRcIwqFAg+OiMTgrr4AAHeNC9Y8noA5CREN9tVXW3pQfjmWK217feMpfJqcBbMANp3Ig8m6Wm1ppREPrdyPxz8/yBVsiYjohsaelOvIVa3CKzP64s5BYTh2sRTb0vKx8Xie9Pi+c0W4XGaAp6sL/nfFxQkzC8oRoHPF3R/swUnruJbsokp09e90XT8DERHR9cKQ4gB9Q7zQN8QLBeUGu5AiBLDheC683NTQV9cixNsN/p5aHD5fguOX9NhwLFcKKIBlEC5DChER3ah4useBYrrUTT2+uZflGkAf7DiLT3efA2CZ1twvxLLP4fOl+GJPNgBA62L5z1Y/sBAREd1oGFIcqF+IF1yUCnhoXfDmrP7o7KlFdlElDmYVQ+uixJ2D
QtE32AsA8PGuTOTqq+HXSYNnJ/YEABy/pMfh8yWoaGRWEBERUUfH0z0OFKBzxWcPDYG71gW+nTR4ZkIP/N93RwEA/7xzAEJ93NHHGlJsZg8Jx4BwbwDA1lP52HoqH+G+7lhxX3yDfYmIiDoyhhQHGx7tL/08a1AY8vQGRPi5Y1psFwBAj6C69Vb6h3rh8THdIIT9rJ7sokr8/qN9eO/3A/Gvzacxf2x3JHTz+81tMNSa8Z8TSuysOY43Zg1o2wciIiJqJwwpMqJSKvDU+O5227QuKrx2az9kFpTjmYk94Wq9cGF9Xf3cca6wEvd+uBc1tWZk5JcjccFo6FzVv+l9U8+XIK1UibSDF/HKjH5w0zR8DyIiouuNY1I6gHuGhuPFab3tAsoDw7vCRanAJw8OxpNjowFAusJynt6ANzem/ebXT8kukX6+UFzZPo0mIiJqI4aUDuovv+uNfS+Ox9ieAZg5IAQh3m4AgOn9gwEAX+7Nxs70Ary5MQ3ni5oPHofOl0g/Z19lXyIiouuFp3s6KJVSAd9OGgCAxkWJjx8YjMMXSnBHfChKq4zYcfoy7vtoLwDgmwPn8eXDQ9E90LPB6wghkHq+VLpfP6TYVrRVKhXX8qMQERE1ij0pN4ieQZ64c1AYlEoF/nhztN1j+WUG/Omb1AbPqTaasObgBRRXGqVtWYWWkGIyC9z/yT4MW7IFxRU117TtREREjWFIuQEN6uqL2+JC0NXPHd8+ngAXpQLHLupx5nI5AKCyphZCCPzp61Q89+0Ru+emZBfj1Z9OYNF3R/BregHyywzYfaYQVTUmLPgmFY99fkC6jhAREclfVY0Jxy+VXn1HGeLpnhvUP+8aIP08srs/ktIu48W1R1FVY8LhC6W4uVcAtqXlAwAi/dwR7FKOXXlKHL5QisMX7A/mfZmF+HT3Oew7VwQAOJWrl9ZkEULAUGtudNYRERFdXxn55SitqsHACF9p2wtrj2JtykV8+tAQxIZ44XReGdw0KsSGejuuob8Re1KcwNR+ljVX9pwtkgLI1lP5EAK4qbs/Nj09EjcHm5t8/hd7s6WAAgDpeZYemYz8Mty+YjdiX97UYVM6EdGNwmQWuOeDPbjzvT04V1CBSyVVKKmswc9HcwAAG47l4Oa3knDX+3twyzu78OPhSw5u8dUxpDiBSb2DoFZZBr9O7ReEB0d0lR6bk2D52UdTt38njQqZS6Yi8U+jAKDB6Z3TeWXIyC/Hre/uxqHsEtSYzNh6Mv+afgYiImrekQslyC8zwGQW+GfiaYz8x1YMeCVRWp7iu0MX7cYg/vX7Y7hUUuWo5v4mPN3jBLzc1fj33XG4WFKFB0dEotZsxtnLFVApFbi5VwDMplqo6sXVyX27QKFQoFtnD3hqXVBmvTbQ3JGR+GhnJg5fKMGmE3nSdgA4fKEUF0uqkHm5Ap09tegZZD+TqNZkhouKmZiIqLXyy6qhVang5d74Qp27Mgqkn39opJfEYA0rt8aF4FRuGU7m6HHT69vw+2EReGl6bygU8pvJyZDiJKZYT/kAgEqpwqcPDZHum02W25d+1wtb0wrw52kxACxTj2PDvLAroxDx4d4Y1ysAH+3MxK6MQgBAgKcWr8zog8e/OITNJ/Ow+0wBKmtMUCiAn/4wEn2CvSCEwGvrT+KTXefwzj1xmNy3rh1ERPTbZOSXY+a7uwAAL03vjTx9NW7pH4K8smqczNHjjoGh+DW9oMnnKxSA7YoqI6P9MW9sNJ5ZcxiHz5dg5e5ziOrcSepZlxOGFJLcNzQcD47sZrdtWr9g7MooxNyRUQ3WWfnDuO4Y0zMAKqUCJrNAZY0l7QgBbD6Rjz7BXvjbzyfx0c5MAMDrG9MwsXeQtO7K8UuliPL3QE2tGetSL6K0yojb4kMQ6uN+HT4tEVH72X+uCKv2ZiOzoAJ/m9kXfUO8kHgiD58ln8Nrt/ZDmO/Vv9eMJjMyCypgFgI9Az3x4Mr92JdZhGBvN5jNAuXW3uuF1lmZ
7+84i3JDLcwCeHNjGvTVlse1LkoYas0I9XHDtNguMJsFUs+XYP+5YgDA0ChfhPq44/t5I/Dhr2fxt59P4m8/nURcmA/6hcrrQrUMKdSs2UPCcGtcCNw0qgYXNrwlNhiuahV6BHriZI4eADC4qw/2nyvGrowCTO0XhI93WQKKm1qFs5crsOlELib37YIv92bhxbXHMDTSF0UVNUjPtwzGPZRdjLE9A7ArowCv3xELb3cNiIgcIb8KGL9sJx4ZFYX7hkXYPWY2C5zKLUNmQQXG9OyMRz87II33eHHtUaybNwKPfX4AZgE88tkBbHh6VLPvlZJdjKdWp0oLaj48MhJJaZcBWHpRAMDT1QXDovywM70A/p4anC+qkrbbAkpXP3f0C/XGj4cv4da4EDwzsScA4NWfTmD/uWKE+rjZ/SE4d2Qk9mYW4ciFEtSYTG0tWbtjSKFmKRQK6YKD9c9X+ntopPOinepdkPDvt/bDxGU7cCi7GP/YcApCAJP6BKJHoCfe3pqBF9YeQ3ZRJZYlpgMA9mZaZg15u6tRUmnE9tOXsTO9ALVmgYBNafjbzH7X66MS0Q2qsqYWeXoDIv07NXisqsYEV7Wy0fEYBy4rkVVUiTUHLzQIKU9/nSqN+xgS6YviSiM8tC4QQuDwhVKsOXABtjkHp3LL8M2B89C5qhHVuRPCfd2lZRsyCyrw7JrDOJhVbPf6H1p7oLt17oQ/3NwdiSfzcPfgMNzUvTOEEKgymvCvzekI8XbDXYPDkJSWj3OFlbipuz86e2gRH+6N2UPCpdeb2DsQH+/KxMwBIXbvo1Ao8OYd/WE0m+HvoW1hZa89hhRqkX/e2R//3pKOd++Nl7b9YVx33P/xPswfG40egZ4I83XD+aIqbLbO+PnDzd0R5uuOLSfzcSJHj9fWnwIABOlckauvBgC8dms/fJ6cheSzhai19tis2puNuweHo2+IpfvxUkkV7v1wL2YMCMbT43tcz49NRB3YC98dxbrUS/juyeH48fAlBHu54ZFRUThxSY/ZH+xBsLcbPrx/kHQNNJuzZZbbzMvlEEJIQabcUIv11mm9ALDP+sfWiGg/9O7ihWWbT+OvPxyze636C2cqFcDACB88OCISnyWfw8GsYigUwC39gzF/bDQm/muHNH7kgeFdMTMuBDPj6sKFQqGAu8YFL0yNkbZdOd7vwRGRdveHRvkh9S8T4eHa8Nd+UwNx5YAhhVrktvhQ3BYfardtdI/OOPDn8fC1npoZGd0ZX+3LBgDcO7QuZKydNxwrks4g9XwJdK5qvDA1Bt8ePA8XlRJT+gah1iyQfNYyKLdviA7HLuoxb9UhrHtyBHw6abA25SIyCyqwam+2FFIMtSZoXbiQHBHVMZrMuFhcha7+nVBrMiPxRB4A4NPd5/B9qqX3Y2KfQPzp61SUVhlRWmXEzHd34d93xyGhm5/0GufKLaFEX12L4kqjdL203RmW3t4QbzdU1NSixHqa56bunXHLgGC8t+OMNEavV5Anyg218HRVQ6NS4OzlCpQZarH/XLE0RkSjUmLD0zchqrMHAGBIV1+pl3lMz4B2q4ucw0hTGFKoXdTvJnxqXHcAAmN7BmBC70Bpu9ZF1aAHZP7N3aWfJ/cJwrTYLgjxdsNjo6Iw8z+7kFVYiUn/2oG7BodJI9fzywzI11fjeI4ej39+EI+NisIC63lXIqJXfzqBz5Kz8MGcQQjUaVFhDQz1Z7/c99FenC+qgr+HBn6dtEjLK8M9H+7Bpw8OwagenXEipwxGc90poEXfHUF6fjn+Mq03tp+2jBWZ0DsQhlqz9EfZTd39oXNVY8aAEGnbfcMi7E4VCSFwobgKb29NxzcHLgAA7h4SJgUUALhlQDD2Zhahe4DHbxpweyNjSKF2F+TliiW3xbb4eRoXJd69p+400kf3D8Y9H+xFfpkBb2/NsNt3++nL0gj3f2/NwIKJPWE0mbE/swj9w7xhNJmRXVSJmC46qJtZnyVPX41nvjmMOwcG
t7i9rXU6rwxf7cvGvLHRsjwHTOQoVTUmfLwrE7syCvDIqCiMvaIXIbe0GmtTLqKkqgZPj+sBN40K3+w/j/T8Mjw/uRdcVEqUVRvxzYHzAIB1KRcRF+4tPb+o3sVSbYNO/zazH0b18MfTq1Ox6UQeVu3NxqgenaVeDpuNxy29MQ+u3C9tG9XDEkq+2peNbp07IcLPMublvmHhUkgZ1NXH7nUUCgXCfN2x5LZY1JoEUi+UYN5Y+4vC3jkoDMUVNRjZvXOLa3ijYUgh2eoR6Imdz4/Fh7+exZubTts9tvCKCyOevVyO5749ggNZxbg/IQKncsuwN7MIPu5q/OP2WEzsE9Toe3y0MxM7MwqwM6MArw68Zh/FzttbM/Dj4UsI0rnisdHdrv4EohvQwawiBHi6QueqxsbjuZjePxh3f7AHh8+XAAB2nynEi1Nj8MioKABA6vkS3PPBHuk0ymW9AX8c1x2L1h6FySzQI9ATiSfycPySHtVGy6JlO05fRmVNbaPvDwDjYwIwua/lu2H+zdHYdCIPv6Zfxqs/ncCnu88BsPzxZFuxtT6NixLDovzgrnHBqoeHIsSnbjxLn2AvLJjQAxU1teh5xdINNiqlwu4aa/WpVUq7XmZnxpBCsuaqVuGJMdFYtTcbl0qr4ddJg8J6fw3Z3PneHhSUGwAA36VcRJl1Ol5xpRHzv0rBqoeHYlBXywW3ckqr4OOugatahVTrFyIA/HxeibubaYsQAmZh+XJpi/S8Mms7qtv0OkQdhaHWhJ3pBQjzdUePQE+cuKTHHf9NRpR/Jwzv5o/P92Rh2ebTyCmtRieNCuNiAvHD4Ut4b8cZPDQyEpdKqvDwp/tRWWNCryBPnM4rw3cpF/FdykXpPf687pi0oqpNmaEW26zTeOuL8u8EjYsSL9/SR9rWN9gLAZ5a5JcZpLWdYrzNmBjfDcu3ngFgGX+3cFJP/HjkEmJDvOGusfwKHR7t3+A9/jiOIaM9cJ1ykj2VUoG180bg+cm98PoddaeREqL8MMX6V5AtoACQAkpU506Y0DsQNbVmLPruKADLaaKR/9iGP36VgmqjyS6k7M1XNBkcTuboMfXfOzHq9W0oq7YMkjObRYO1Y67GZBbILKgAYDnVRNTRlFUb8fqGU3azW5pz5nI5RizdirmfHsDs9/eg1mTGhmM5EAI4c7kCn+/JAlAX2m+LD8Vbd/aHl5saBeU1+DX9Mh7/4iAKymvQJ1iHb58YbjerxaZ+QNG6KDEi2k+6H6RzhXu9pRJW3DcQG54eZbdeiFKpwOgedadXHh8VicdjzOgdrJO2jYz2R98QLyyaEoNpsVw9+3pgSKEOIVDniifGdMOQyLrLj88dGWm3OuLQSF8MCPOW7t8U7Y837+gPhQJIzy9HdmEl/rLuGExmgU0n8vB96kXU1JrR2VOLQRHeEFBgnXXkf63JDLN1kYOswgrc+p9dOJmjx8WSKhzIKka10YQZ7+7CpH/tQLXR0v18/FIpks8UNrggY32XSqqkL1OGFPotSquM0mKJrSGEgMksUG004cu9WSiqqMHG47lY8stJVNS7/taO05cxcdl2DHhlE/6TdBbbLikwaflOpGTXjc24UFyJ8f/cjv8kncHCNYcbPQ1SUlmDRd8dxU7rINUPf81EQbml97Owogan88ql5QkaM3tIONQqJSZaB90/8Ml+HL+kh28nDT6YMwgeWhc8fFMUVj44GAPCvPHITZEI1FnGdgXpXLHlmdH48Q8jMXekZQpupH8nfPnIUHS1jhdRKIAIv8YHo94ywDI2LdzXHfPGWE4zda03cHVEIz0mdG3xdA91KJ6uarwyow8Ky2swLiZAWhAJsFw061xhpdQ7ktDND17uavTuosPxS5bpzLbVHAHg+f9ZeleGRflhZDcfHMgqwZqDF+GqccE7WzPQP8wbn88dii/3ZkvnuAHg6IVSHDlfiqMXSwFYLuoVG+qN2/6zG4ZaM4K9XLHmieEN1lwAgIzL5dLPeXpD
g8frM5rMcFEqZHnRL7p+nlqdgqS0y1g3b4RdCP8tqmpMmLJ8B7zdNRga6Yv3dpzFiUt6bDiWi8KKGmw4lou1T46AbycN3ttxBqfzLMfnO0lnAKGE0VyJRz47gLVPjkCYrzs+/DVTOm4rakw4lF2MYVF1U3YVAP7+80msOXgBe84WYvOC0dh0PNeuTeuP5uBEI6FLobBMvbX1XEzt1wVrDlpmv6hVCrw9Ow7B9f6fGtMzQJqeq3FR4t1tZ/DwTZHoZp0l0yPQEz/OH4luAZ3grnFBpH8nnMjRI8Tbze57o76bunfGF3OHokeQh7RPhJ87Rkb7w0WlQK+gxseX0LXT4p6UHTt2YPr06QgODoZCocC6deuu+pykpCTEx8dDq9UiOjoaK1eubEVTiSzmJHTFnyb0gEKhQL9QL7hrVPDQumBK3y7SGgcKBTA00vKz7dYWKqb1s++mHd2jMyb3CYRWJXC+uAqvrT8FfXUtfk0vQFFFDf5n/aIcbn3tDcdy8Z+kutlGG4/nYsfpy1IPyaXSanyfehGNOZNfF1Lyy6ql3pr60vPKMOfjfej+4i9YZZ0hcKVqowkHs4pbfLpJCIE1B87jxKXW/2VuU1pltDvNRu2vpLIGO6zTXfda1xC6mj1nC/HC2qOYt+oQEk/mScHddomKX6wBBQCyCivx7y3pqDWZkZpdAsBy4VCjSUjTbwvKazB/1SEYTWbssbZBY50xZxuYuuCbVPRfvAmxizdJwSKzoAJf7MlCYUUNvN3VeNw6SPydbZb/dzQudb9+7h4chsQ/jcYH9w+Sto2I9keUfycEe7ni68cSmu3FeHp8D/wwf4TUe2Jj+X6w/C1u6z1pbNXZ+kZ290eAp6t0X6VU4IuHh2Llg0Ok647R9dPikFJRUYH+/fvj3Xff/U37Z2ZmYtq0aRg7dixSU1Px9NNP4+GHH8bGjRtb3FiiK3m5qbHm8QT874nh8HJXY1iUL27uFYCHRkTCx7rwUv1TREE6V7x1Z3/0CvJEJ40KfxrfA7fFhcBd44KpYWaE+bghvt6Uxbe3pqOwogYBnlppINyJHD0MtWYEeFq6mBNP5GFrWr7UHgDYcjIfZrNArcm+O/ysdTwKABhNAsWV9oOAy6qN+P1H+6RfTLaAdKXXN6Th9hW78d2hxsNQU35NL8DCb49gwTepLXrelYQQuPO/yZi4bAdKq4xteq0bQXFFTbOn+RpTWG6QThU2ZUd6gbS0ev1TPqfzyvDsmsPIKqyw27+s2oj7P96HVXuz8fORHPxlXd2qp0aT5YWKrhh4/uPhSzh+SY+KGhM8tS5YenvdpSj+Mq0XdK4uOHyhFK/+dAJp1kHfz07qYW3fZSzfnI7vDl1EZY1Jmnlj6/x76YfjAIAJMYEYEmk/FffxejPbBnX1RXSAB3SudYuNaVyU2PSnUdjx3FjEh9s/90pqlRKxod7N9jqOiwmEbycNpsdev+UGqO1afLpnypQpmDJlym/e/7///S8iIyPx1ltvAQBiYmKwc+dOLFu2DJMmTWrp2xM10Ce4blyK1kWFjx8YbPd4/ZDy+4QIuKpV+H7+CCgVCmkNFZMJGNNF4PW5N0GtVuORzw4g8UQevrAO6rs1PgT9Q73tXveVGX2w6LujKK404ucjlkGEi2/pg6e/TsXBrGLEvZqI0iojwnzdsO7JEfDz0Nr1pADALe/sQoBOi6W3xaJnkCfe2JgmXSoAsPT+VBtNDbqnt56yrNmw8Xgubh9ovwJwc2x/CafllaGyplb6K7OlcvXV0i+sIxdKcJOTrucghMAHv57Fkl9OYXKfIKy4r+E89oNZRZj3ZQqen9ITt8ZZ/lsdvVCKO/67G7GhXvj60YQm/0JPOlU3duNUrqXelTW1mLhsBwBAAeCNWf2lfXZlFNgNIG0uQN41KAybT+ahsKIG/95iuZZWXIQPxvYMwAMJ4Thz9hzuGxKGAJ0b/vBVCj5Ltvy/ENW5E26NC8Vr60/h2EU90qztWn73AHi5
qXH4fCl0bi5Y/OMJ6b1ujQ9Br6C6AagaFyWeHNMNKdnFOHy+BKO6N95L4tLMGkctNTDCBwf/PJ6nTzuYaz4mJTk5GePHj7fbNmnSJDz99NNNPsdgMMBgqOtG1ustf0EYjUYYje33V5vttdrzNTuiG70OnhoFftcvCCdzy3BnfDCMRqOlC1EARrPlL78raxAT6IHEE3nSX5839/CHi8K+V2R0tC9uHRCMj3dnSdsmxvijW+dOOHO5QvoFcb6oCj8evojCcgMOWQchqlUKGE0CF0uqcLGkCtPf2YkPfx8nzXT49IGBeO5/x5BXZsCBzALoXNVYvjUDCyf2gM7VBecKLWNr9pwtRLWhBiqlAtVGE77afwETYgIQ6tNwPAxguZw8AAgBHLtQjH7BOny6Jxujov3R1VdrV4PmHL9QN5jy2IUSDOvqbfe4wWjCih2Z6OLlirsGNR2iMgsqUGU0oXcXXZP7tFZNrRmJJ/MxopsfvBtZDtxQa0bq+RIMjvCxCwlXHgtVNSbszChEVz93dA/0sHuNlclZeG19GgDLaZTd6fkY3NUHPxzOwdrUSwj3dUNOaTVy9dV49acT6BXQCTml1Xgn6SwMtWbsP1eMb/ZnYeaAYGQWVCDUp268RK3JjG1pdSElI78c+aUVWPzjKWnb4fMlWJ6YhqTTl9HVvxMKyizfm7fGBeOHwzlS706PAA9kXC5HmI87sqzjsmJDPKFRAZ/vPY8t1jA0IFSH2tpaPDehGxITz8JkqsXk3p0xsXcANp2w7DM4whverkoMj/LF7rNFMJoExvTwx9Q+lvEhI6J8UFJpxD82nEK10YxXb+mNweGWPyQ6aVWoMJjw6MiuUMGMFfcMgNFkhodWJbvvnxv9e/G3amsd2lo/hWjpSe36T1YosHbtWsycObPJfXr06IEHH3wQixYtkratX78e06ZNQ2VlJdzcGn6Zvvzyy1i8eHGD7atWrYK7u3MvEUzXx9EiBT5Ms/yycHcR+PsgE5QK4H+ZSuzIVWJ8sBnTI8wwmYFP05U4XKREnJ8ZD/QwY8N5BX65oIJOLRDvL5CUo4S7SqDSZPlFGOtrRo0JOFVq/1diiLvAxUoFgtwEFg0wYeVpJVIKlZgaZsLGC0qYhALdPAVu6mLGytN1PStze5oQ1klge44S23KU6OVlxhO9zRACMJoB28zLWjPwf/tUMApLO2ZFWj7T12dViPAQWNCv8VMP5UZgxUkVIj0E7oiyBLUtFxX4IdvywgP9zZjTvS7A1ZiAj9KU0uf7Y59adGskg9SagZcOqWAwAS/Fm+CprtteWgP4uTZ8TkusylBi72Ul4v3MuL9Hw1koX59RYne+EtPCTJgY2vjXYGYZ8N5JFapMCri7CLw4wIT0UgV6eguoFMDLh1SorFWgs6vA5WoFonUCIwPNWJne+MBMBQQE7P+Sd1EIqJVAlUkhHUMAsD1Hge/OqdDJxbI+T5WpYQ+ARilgNKPBaz4RY8KWSwqcLlXCQy3w1zgTyoxASqECP1n/uy2MtczseeuICmbr85+MMaGnd8NaFBuAlw9Z/qa9L9qEwZ0Fas3AyRIFsssVGNXFLP33szlfDpgFEFFvrOlZPXC2TIGxXQTasZOEZKyyshL33HMPSktLodO1/I8RWc7uWbRoERYsWCDd1+v1CAsLw8SJE1v1IZtiNBqRmJiICRMmQK3ueBdeai+sQ8MaxJVW48M0S5f6uN5d8LtplvVZRhlqcSCrGKOi/aW/vqeYzNh1phCxoV7wcddgTE0thqbmYEJMAEoqa5D0TrIUUB69qSuendAdL35/AqcOWsaTBHpqkVdmwMVKyz7De4Vg6tS+KPbLRspPp3CkrBNMwnIKKN+ohsk3GEDdgNqP0lT1BiGakV6mwmltNFbuzkKl0YS/z+iNWQNDkXq+BMa9+6TnqfwjrGtTFCCrXIH44aPx9S878HOeJ2YMCMaToyOhUCjw5+9P4ELFBVyoUGD5w+Ph6eqCrd8eBbItp7j0
Sk9MnTpCmjn15d7zOFVa17uUWOiDJ+8c1mARvD1ni1C+9wAAoHOvwbi5p+US9I9+kYKk0wVY/fBgbDieB38PLR4bZRkQmVNaDU9XF3hom//q+vloLvYmW1YlPlqigl/MEJwvrsLM/l3golIip7Qaz+77FYDAzgIt7p0Yh/PFlXDXuCChqxd2bd+KMTePw9vvH0CVyTLuo7JWgbdOuqGowohIP3cMi/JFZe0FdPVzx8oHBmLCv3YiQw8oXHUAKuwWHnRVK1FtNENAATe1ElVGMxaMj8aPR3KQnl+BWms+PF7qgrHjx+J8cSVe/GA/gFo8N6U3fjySgwNZJQCAEG9XLJ4egydWpaKmiSEt8+6chIgjOfi/tccxsW8Ibp3eFwAQfaYQP608CI2LEg/eNhlqlRL9BxVh6YbTqDaa8Pgdw+CmUTX6nRDS5zJ2nynCsxO6S8fbLc3+V+jY+L1o0dY62M6EtNY1DylBQUHIy8uz25aXlwedTtdoLwoAaLVaaLUNr2miVquvycFyrV63o2Ed6moQ5ucC304aFFXUYFxMkFQXH7UaE/q4XfEcYHyfusF4Xmo1HhhhWWMh2Ecg1McNF4qroHVR4rHR0dBoNAjyqnuNJ8dGSwMMASAu3AdqtRojewQCOIULJXVjVCoMtdiZYRlXMijCBweyLKdd6q9XYTILvJt0Vrq/+VQB7hkWidQLlrEDWhclDLVmHMwqkbr+AeCbQzn4OE2JSlMl/rUlAwaTwAPDu0prxwBA6kU9bu4ViNP5dQM2zxZUIq/ciJn/2YPKegNB37gjFq/+dAInc8uwPaMIk/oEYfeZAmQVVuKuQWHYebZI2vdkbjkm9Q3GpuO5SDptWV/j1fVpOG6dhRQd6In8MgNe/uE4fNw1uG9YBKqMJjw+uht8O2kghMDlcgMCPF1xvqgSf/neMh7CRWk5rXbfx5YwlJZXgRemxuDdpEzpVF5pVS3u+ajueiyeri64K0KBgtQ8ZFyugG8nDeaOjMQbG9NQVGHpus4srESm9ZTbvLHR6NpZh7E9A7DpRB7SrbVZfncc5q06hNIqI56d2BOXyw3wdrO8VlFFDYK8XPHo6GicvWzZ/6GV+5Grr8aCb49iW9plmMwC/UK8cF9CJA6dL5VCykcPDEavIB26B5yRpvLePTgMY3sFYN6XhzAzLgQeblrcNSQC4X4e6BvqJR2/I7oH4O7BYegV5Al3V8t37MgegfipRyCEEA3Ga9T/TpjQJxgT+jjfoFN+L1q0tg5trd01DykJCQlYv3693bbExEQkJCRc67cmajWFQoH/m9ILe84WStf2aO3rTO4ThA93ZmLWoFD4WS8oqK83oPHOQWH4x4ZT0syIWOsA3egADyya0gtLfqkbg2AWlnEcSgWw9PZ+eGvTaQyN9MWJHD22nrqMIZE+WH/Usi6F7S/5tNwymMwCX+3Plt7v8z1ZSL9iEO/b284CUCDU2xUXSqrx3+1nkHm5wm4gZvKZQtzUvbM0ANg2tubP646hrN7CYJP6BGLWoDCk55fj/R1n8UPqJRSUG/DndccgBJCaXYKU8/XGtVwsRa3JjNfWn5S2Ha83TfrxLw5JPxdW1GC5daCn2Szw4rQYzF+Vgp+P5mD2kHAcvViCMkMt4sO9MaF3EP6xoa5+K3efw2fJ56QZM7fFh+C7QxehtE5Zv1hSheyiSvwvUwltjmXK7p8m9MDt8SH44NezKKk04qERkbhcbkBOSRW6B3piZlwIAMtCYJtOWP4g8/fQYng3P7x7Tzy2peXjvmERdoOfg7ws57Jc1SppXZCbYwKwam+2tNDZhN6BeGVGH6iUCvx+WAR2nynEggk9pAGoMV10UkhJ6OaHSX2CsPeFcdIMM4VC0WC5dheVEktvb/zinxxQSnLU4pBSXl6OjIy6NSIyMzORmpoKX19fhIeHY9GiRbh48SI+++wzAMDjjz+Od955B8899xweeughbN26Fd988w1+/vnn9vsURNfAnYPCcOeg
sDa/zoKJPdA3xMsu7NwxMAyfJmdhRLQf3DQqxIZ6Yc/ZIqhVCvTqUncS/7HR3dDF2w1puXqk55VLvwTHxQQiOsDTbjaJEAJZhZXYcCwXLkolPn1oCH739k5cLKnCtwfP4+zlCni5qbFwck/8cPiSNLC3f6gXDl+wrCHjqxX4+tGh+L+1x/FregE2WBfiujUuBGtTLmL3mULszChAjcmMThoV+od5Y/eZQiRZr4+SEGUZpGq7Jsot/YPx/o6z+PloDn6ut4z619ar1Nocu6jHL8dypQHBjdG4KPHUuO4QQuCtxNMQAth6Kh/+nlrptW1XntW5umD53XFwVavw3o4z8HHXYMaAYLyzNQO1ZoEATy2eHNMN9w/viun9gxHm44boAE9UG00Y8vfNKK6uBWoM8HFX485BodC6qPDefQNx9GIpHhjetdFZJ+N6BaKTRoWKGhPG9OwMpVKBkd39MbKJmStXGt2jM1bttbR/SFdfvP/7gVJwGNTVF/tftJ+AEFPvOLEtqObHq2rTDabFIeXAgQMYO3asdN82duT+++/HypUrkZOTg+zsuvPlkZGR+Pnnn/GnP/0Jy5cvR2hoKD788ENOPyan4a5xkf7atukX6oVtz46RlvOOD/fBnrNF6N1FB62L/aDLW/oHA/2DsXxzuhRS7k/o2uB9FAoFuvp3wsoHh6CTVoW+IV4I0rkiV18tra77wPCu0Lmq8fEDg7A25SIulxnw52m9MefjfaiqqcXDURUI8NTi3qER+NW6rHl0gAf+b0ovrE25iOOX9HjwE8upkeHR/rhvWAT2ZhbBZBYI93XHFw8PtRt70idYJ7UBsFzKYECYN/7y/TGUVBrRP8wbRy6UIFdfLfV4PDWuO9YcOI9L1mu5rP/jTTALgeiAulVA5wzvivhXEnG2oAJLrT1Ndw0Kw6/plxEX7oOFk3oizLqc+a/PjYVapYSrWoW5IyNRbqhFZw+tFDTGWlctBSw9G9Nju+DLfZYQdcfAUOm/x9AoPwyNqrsezJXcNCrcNTgcn+zOxG3xIU3u15QR0f5wU6tQZTThL7/rfdWejYERlrVDegV5IlDXxlHGRDLV4pAyZsyYZle5bGw12TFjxiAlJaWlb0V0Q6u/8uUdA0Ox6UQe7h/etcn9x8UEYNnm0wjzdbO7eNqVRtW7SFqPIE8pIKhVCun1B0b4YmBE3foxm/40CgZDDbYkWhZZHB8TIIWLh0ZEIlDnallnwjoG5vb4UPz1d73h5a5G0rNj8PX+85jcN6jB4FiFQoGHRnbFa+tPYULvQLw4NQZKpQIT+wQiNbsE3QM9Meu/u3HmcgUuFFdB46LEfcMikFVYgXWplxDu6253gTcbnasag7v6Itm67suoHp2x9PZ+jf5i96y3QJinq9rufmPuiA+RQsrdQ8Kb3fdKL06Lwbyx3VrVo+GhdcGqR4aiptZsd02qpsSF++Cj+wchqrPHVfcl6qhkObuHyNlEdfbA5gWjm92nb4gXvp83AsHebr95/ECvIE9p9dqxPQPga12F90pqlRKoN2bCRaXEivvicTCrGHcNtpzy+nzuEJwrqESQl6vd64T5uuPZST2bbMPckVEY1NUXsSFe0oworYtK6pUYGe2PM9bBo/cnRKCzpxaT+3bButRLmN6/6SvNjuzuL4WUF6fGtNuYij7BnpgZYUJcbF/pOjC/lUqpaNMpl7irrKx6pXExga1+L6KOgCGFqAPp38ILzEXX+yXb0lMQceE+dr803TUujfZqXI1KqWh2WfNFU2Pwu/7BCPF2ky4gN7lvEHYsHItg76ZPY9w5KAwbj+diYu9A9GzHC78pFAqMDRaYOqTt45GIqG0YUohuYH1D6k4bjO0V0MyejuOqVmFwV98G28P9ml+4sbOnFj/MH3mtmkVEMsCQQnQD6x2sw4p74xHi49ZgQC4RkdwxpBDd4Kb0a3pcBxGRnPHqCURERCRLDClEREQkSwwpREREJEsMKURERCRLDClEREQkSwwp
REREJEsMKURERCRLDClEREQkSwwpREREJEsMKURERCRLDClEREQkSwwpREREJEsMKURERCRLHeIqyEIIAIBer2/X1zUajaisrIRer4darW7X1+5IWAfWAGANbFgH1gBgDWzaWgfb723b7/GW6hAhpaysDAAQFhbm4JYQERFRS5WVlcHLy6vFz1OI1sab68hsNuPSpUvw9PSEQqFot9fV6/UICwvD+fPnodPp2u11OxrWgTUAWAMb1oE1AFgDm7bWQQiBsrIyBAcHQ6ls+QiTDtGTolQqERoaes1eX6fTOfVBaMM6sAYAa2DDOrAGAGtg05Y6tKYHxYYDZ4mIiEiWGFKIiIhIlpw6pGi1Wrz00kvQarWObopDsQ6sAcAa2LAOrAHAGtg4ug4dYuAsEREROR+n7kkhIiIi+WJIISIiIlliSCEiIiJZYkghIiIiWXLqkPLuu++ia9eucHV1xdChQ7Fv3z5HN6lVlixZgsGDB8PT0xMBAQGYOXMm0tLS7PYZM2YMFAqF3b/HH3/cbp/s7GxMmzYN7u7uCAgIwMKFC1FbW2u3T1JSEuLj46HVahEdHY2VK1de64/3m7388ssNPmOvXr2kx6urqzFv3jz4+fnBw8MDt99+O/Ly8uxeo6PXoGvXrg1qoFAoMG/ePAA35nGwY8cOTJ8+HcHBwVAoFFi3bp3d40II/PWvf0WXLl3g5uaG8ePHIz093W6foqIi3HvvvdDpdPD29sbcuXNRXl5ut8+RI0dw0003wdXVFWFhYXj99dcbtGXNmjXo1asXXF1d0a9fP6xfv77dP29TmquD0WjE888/j379+qFTp04IDg7GnDlzcOnSJbvXaOz4Wbp0qd0+cq7D1Y6FBx54oMHnmzx5st0+Hf1YuFoNGvt+UCgUeOONN6R9ZHUcCCe1evVqodFoxMcffyyOHz8uHnnkEeHt7S3y8vIc3bQWmzRpkvjkk0/EsWPHRGpqqpg6daoIDw8X5eXl0j6jR48WjzzyiMjJyZH+lZaWSo/X1taKvn37ivHjx4uUlBSxfv164e/vLxYtWiTtc/bsWeHu7i4WLFggTpw4Id5++22hUqnEhg0bruvnbcpLL70k+vTpY/cZL1++LD3++OOPi7CwMLFlyxZx4MABMWzYMDF8+HDp8RuhBvn5+XafPzExUQAQ27ZtE0LcmMfB+vXrxYsvvii+++47AUCsXbvW7vGlS5cKLy8vsW7dOnH48GFxyy23iMjISFFVVSXtM3nyZNG/f3+xZ88e8euvv4ro6Ggxe/Zs6fHS0lIRGBgo7r33XnHs2DHx1VdfCTc3N/Hee+9J++zatUuoVCrx+uuvixMnTog///nPQq1Wi6NHj17zGgjRfB1KSkrE+PHjxddffy1OnTolkpOTxZAhQ8TAgQPtXiMiIkK88sordsdH/e8RudfhasfC/fffLyZPnmz3+YqKiuz26ejHwtVqUP+z5+TkiI8//lgoFApx5swZaR85HQdOG1KGDBki5s2bJ903mUwiODhYLFmyxIGtah/5+fkCgNi+fbu0bfTo0eKpp55q8jnr168XSqVS5ObmSttWrFghdDqdMBgMQgghnnvuOdGnTx+75911111i0qRJ7fsBWumll14S/fv3b/SxkpISoVarxZo1a6RtJ0+eFABEcnKyEOLGqMGVnnrqKdGtWzdhNpuFEDf+cXDll7LZbBZBQUHijTfekLaVlJQIrVYrvvrqKyGEECdOnBAAxP79+6V9fvnlF6FQKMTFixeFEEL85z//ET4+PlINhBDi+eefFz179pTu33nnnWLatGl27Rk6dKh47LHH2vUz/haN/XK60r59+wQAkZWVJW2LiIgQy5Yta/I5HakOTYWUGTNmNPmcG+1Y+C3HwYwZM8TNN99st01Ox4FTnu6pqanBwYMHMX78eGmbUqnE+PHjkZyc7MCWtY/S0lIAgK+vr932L7/8Ev7+/ujbty8WLVqEyspK6bHk5GT069cPgYGB0rZJkyZB
r9fj+PHj0j71a2bbR041S09PR3BwMKKionDvvfciOzsbAHDw4EEYjUa79vfq1Qvh4eFS+2+UGtjU1NTgiy++wEMPPWR3YU5nOA5sMjMzkZuba9deLy8vDB061O6/u7e3NwYNGiTtM378eCiVSuzdu1faZ9SoUdBoNNI+kyZNQlpaGoqLi6V9OkpdAMv3hEKhgLe3t932pUuXws/PD3FxcXjjjTfsTvXdCHVISkpCQEAAevbsiSeeeAKFhYXSY852LOTl5eHnn3/G3LlzGzwml+OgQ1xgsL0VFBTAZDLZfREDQGBgIE6dOuWgVrUPs9mMp59+GiNGjEDfvn2l7ffccw8iIiIQHByMI0eO4Pnnn0daWhq+++47AEBubm6j9bA91tw+er0eVVVVcHNzu5Yf7aqGDh2KlStXomfPnsjJycHixYtx00034dixY8jNzYVGo2nwhRwYGHjVz2d7rLl95FKD+tatW4eSkhI88MAD0jZnOA7qs7W5sfbW/zwBAQF2j7u4uMDX19dun8jIyAavYXvMx8enybrYXkNOqqur8fzzz2P27Nl2F4374x//iPj4ePj6+mL37t1YtGgRcnJy8M9//hNAx6/D5MmTcdtttyEyMhJnzpzBCy+8gClTpiA5ORkqlcrpjoVPP/0Unp6euO222+y2y+k4cMqQciObN28ejh07hp07d9ptf/TRR6Wf+/Xrhy5dumDcuHE4c+YMunXrdr2beU1MmTJF+jk2NhZDhw5FREQEvvnmG1n94rxePvroI0yZMgXBwcHSNmc4Dqh5RqMRd955J4QQWLFihd1jCxYskH6OjY2FRqPBY489hiVLltwQy8Pffffd0s/9+vVDbGwsunXrhqSkJIwbN86BLXOMjz/+GPfeey9cXV3ttsvpOHDK0z3+/v5QqVQNZnbk5eUhKCjIQa1qu/nz5+Onn37Ctm3bEBoa2uy+Q4cOBQBkZGQAAIKCghqth+2x5vbR6XSyDAHe3t7o0aMHMjIyEBQUhJqaGpSUlNjtU/+/+Y1Ug6ysLGzevBkPP/xws/vd6MeBrc3N/b8eFBSE/Px8u8dra2tRVFTULseGnL5TbAElKysLiYmJdr0ojRk6dChqa2tx7tw5ADdOHWyioqLg7+9vd/w7y7Hw66+/Ii0t7arfEYBjjwOnDCkajQYDBw7Eli1bpG1msxlbtmxBQkKCA1vWOkIIzJ8/H2vXrsXWrVsbdMM1JjU1FQDQpUsXAEBCQgKOHj1q9z+o7Uusd+/e0j71a2bbR641Ky8vx5kzZ9ClSxcMHDgQarXarv1paWnIzs6W2n8j1eCTTz5BQEAApk2b1ux+N/pxEBkZiaCgILv26vV67N271+6/e0lJCQ4ePCjts3XrVpjNZinEJSQkYMeOHTAajdI+iYmJ6NmzJ3x8fKR95FwXW0BJT0/H5s2b4efnd9XnpKamQqlUSqdAboQ61HfhwgUUFhbaHf/OcCwAlp7WgQMHon///lfd16HHQYuG2d5AVq9eLbRarVi5cqU4ceKEePTRR4W3t7fdrIaO4oknnhBeXl4iKSnJbspYZWWlEEKIjIwM8corr4gDBw6IzMxM8f3334uoqCgxatQo6TVsU08nTpwoUlNTxYYNG0Tnzp0bnXq6cOFCcfLkSfHuu+/KavrtM888I5KSkkRmZqbYtWuXGD9+vPD39xf5+flCCMsU5PDwcLF161Zx4MABkZCQIBISEqTn3wg1EMIyUy08PFw8//zzdttv1OOgrKxMpKSkiJSUFAFA/POf/xQpKSnSrJWlS5cKb29v8f3334sjR46IGTNmNDoFOS4uTuzdu1fs3LlTdO/e3W7aaUlJiQgMDBS///3vxbFjx8Tq1auFu7t7gymXLi4u4s033xQnT54UL7300nWdgtxcHWpqasQtt9wiQkNDRWpqqt33hG2Gxu7du8WyZctEamqqOHPmjPjiiy9E586dxZw5czpMHZqrQVlZmXj22WdFcnKyyMzMFJs3bxbx
8fGie/fuorq6WnqNjn4sXO3/ByEsU4jd3d3FihUrGjxfbseB04YUIYR4++23RXh4uNBoNGLIkCFiz549jm5SqwBo9N8nn3wihBAiOztbjBo1Svj6+gqtViuio6PFwoUL7dbHEEKIc+fOiSlTpgg3Nzfh7+8vnnnmGWE0Gu322bZtmxgwYIDQaDQiKipKeg85uOuuu0SXLl2ERqMRISEh4q677hIZGRnS41VVVeLJJ58UPj4+wt3dXdx6660iJyfH7jU6eg2EEGLjxo0CgEhLS7PbfqMeB9u2bWv0+L///vuFEJZpyH/5y19EYGCg0Gq1Yty4cQ1qU1hYKGbPni08PDyETqcTDz74oCgrK7Pb5/Dhw2LkyJFCq9WKkJAQsXTp0gZt+eabb0SPHj2ERqMRffr0ET///PM1+9xXaq4OmZmZTX5P2NbQOXjwoBg6dKjw8vISrq6uIiYmRrz22mt2v8CFkHcdmqtBZWWlmDhxoujcubNQq9UiIiJCPPLIIw3+MO3ox8LV/n8QQoj33ntPuLm5iZKSkgbPl9txoBBCiJb1vRARERFde045JoWIiIjkjyGFiIiIZIkhhYiIiGSJIYWIiIhkiSGFiIiIZIkhhYiIiGSJIYWIiIhkiSGFiIiIZIkhhYiIiGSJIYWIiIhkiSGFiIiIZIkhhYiIiGTp/wEoX0lQG6yX/AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot([i[\"step\"] for i in record_dict[\"train\"][::50]], [i[\"loss\"] for i in record_dict[\"train\"][::50]],\n",
    "         label=\"train\")\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "869bd29321bc0af7",
   "metadata": {},
   "source": [
    "## 推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "efc609cabc2d694a",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T16:19:56.180318Z",
     "iopub.status.busy": "2025-01-23T16:19:56.179879Z",
     "iopub.status.idle": "2025-01-23T16:19:57.024341Z",
     "shell.execute_reply": "2025-01-23T16:19:57.023904Z",
     "shell.execute_reply.started": "2025-01-23T16:19:56.180296Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_353/27374796.py:28: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(\"checkpoints/03_text_generation.ckpt\", map_location=device))\n",
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this same papers,\n",
      "Will have my master died the town.\n",
      "\n",
      "KING EDWARD IV:\n",
      "But seek to be revengeful to revenge\n",
      "That I have said to he"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 124/1000 [00:00<00:00, 1234.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aven,\n",
      "It would not stay all reason.\n",
      "\n",
      "KING RICHARD III:\n",
      "Why should they have publied.\n",
      "\n",
      "PETRUCHIO:\n",
      "Come, sir, away!\n",
      "\n",
      "DUKE VINCENTIO"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▌       | 253/1000 [00:00<00:00, 1266.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ":\n",
      "What is his order his throne.\n",
      "\n",
      "GLOUCESTER:\n",
      "My lord, here comes the field\n",
      "What you do well deliver me.\n",
      "\n",
      "TRANIO:\n",
      "Ay, wherefore had not, s"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 39%|███▉      | 390/1000 [00:00<00:00, 1309.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ir, to do this best drown'd in this garden's blood,\n",
      "But that you have stood away.\n",
      "\n",
      "KING RICHARD III:\n",
      "Here is the matter?\n",
      "\n",
      "COMINIUS:"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 521/1000 [00:00<00:00, 1228.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Here's one with him to this place;\n",
      "This is he?\n",
      "\n",
      "MARIANA:\n",
      "Go to, go, bear them make\n",
      "Makes me with her to the walls,\n",
      "And set dow"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|██████▍   | 648/1000 [00:00<00:00, 1240.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n to him, it is too few, a pair of blood upon\n",
      "A means the same instrument.\n",
      "\n",
      "DUKE OF YORK:\n",
      "What is the master?\n",
      "\n",
      "CLAUDIO:\n",
      "The m"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 77%|███████▋  | 773/1000 [00:00<00:00, 1156.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ost man better than the city, and therefore then\n",
      "And that the wolf so much\n",
      "For visiting him, and let them banished wi"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|████████▉ | 890/1000 [00:00<00:00, 1141.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "th thee.\n",
      "\n",
      "MIRANDA:\n",
      "O honour, that thou canst deliver,\n",
      "The belly answer the like.\n",
      "\n",
      "LUCIO:\n",
      "He that we will have "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:00<00:00, 1202.46it/s]\n"
     ]
    }
   ],
   "source": [
    "def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True):\n",
    "    # tempareture: 控制随机性，越高越随机，越低越相似\n",
    "    # stream: 是否流式生成，即一次生成一部分字符，返回一个字符，再输入下一个字符，直到达到max_len\n",
    "    input_eval = torch.tensor([char2idx[s] for s in start_string]).to(dtype=torch.int64, device=device).reshape(1, -1)\n",
    "    hidden = None\n",
    "    text_generated = []  # 用来保存生成的文本\n",
    "    model.eval()\n",
    "    pbar = tqdm(range(max_len))  # 进度条\n",
    "    print(start_string, end=\"\")\n",
    "    # no_grad是一个上下文管理器，用于指定在其中的代码块中不需要计算梯度。在这个区域内，不会记录梯度信息，用于在生成文本时不影响模型权重\n",
    "    with torch.no_grad():\n",
    "        for i in pbar:  # 控制进度条\n",
    "            logits, hidden = model(input_eval, hidden=hidden)\n",
    "            # 温度采样，较高的温度会增加预测结果的多样性，较低的温度则更加保守。\n",
    "            # 取-1的目的是只要最后，拼到原有的输入上\n",
    "            logits = logits[0, -1, :] / temperature\n",
    "            probs = F.softmax(logits, dim=-1)  # 算为概率分布\n",
    "            # 从概率分布中抽取一个样本,取概率较大的那些\n",
    "            idx = torch.multinomial(probs, 1).item()\n",
    "            # 把idx转为tensor\n",
    "            input_eval = torch.Tensor([idx]).to(dtype=torch.int64, device=device).reshape(1, -1)\n",
    "            text_generated.append(idx)\n",
    "            if stream:\n",
    "                print(idx2char[idx], end=\"\", flush=True)\n",
    "    return \"\".join([idx2char[i] for i in text_generated])\n",
    "\n",
    "# 03_text_generation.ckpt\n",
    "model.load_state_dict(torch.load(\"checkpoints/03_text_generation.ckpt\", map_location=device))\n",
    "start_string = \"this \" #这里就是开头，什么都可以\n",
    "res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1be041d2e307e3eb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T16:20:05.083340Z",
     "iopub.status.busy": "2025-01-23T16:20:05.082901Z",
     "iopub.status.idle": "2025-01-23T16:20:05.086583Z",
     "shell.execute_reply": "2025-01-23T16:20:05.086114Z",
     "shell.execute_reply.started": "2025-01-23T16:20:05.083320Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"same papers,\\nWill have my master died the town.\\n\\nKING EDWARD IV:\\nBut seek to be revengeful to revenge\\nThat I have said to heaven,\\nIt would not stay all reason.\\n\\nKING RICHARD III:\\nWhy should they have publied.\\n\\nPETRUCHIO:\\nCome, sir, away!\\n\\nDUKE VINCENTIO:\\nWhat is his order his throne.\\n\\nGLOUCESTER:\\nMy lord, here comes the field\\nWhat you do well deliver me.\\n\\nTRANIO:\\nAy, wherefore had not, sir, to do this best drown'd in this garden's blood,\\nBut that you have stood away.\\n\\nKING RICHARD III:\\nHere is the matter?\\n\\nCOMINIUS:\\nHere's one with him to this place;\\nThis is he?\\n\\nMARIANA:\\nGo to, go, bear them make\\nMakes me with her to the walls,\\nAnd set down to him, it is too few, a pair of blood upon\\nA means the same instrument.\\n\\nDUKE OF YORK:\\nWhat is the master?\\n\\nCLAUDIO:\\nThe most man better than the city, and therefore then\\nAnd that the wolf so much\\nFor visiting him, and let them banished with thee.\\n\\nMIRANDA:\\nO honour, that thou canst deliver,\\nThe belly answer the like.\\n\\nLUCIO:\\nHe that we will have \""
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
